{{def_kernel("A", "B")}}
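    # Persistent matmul using device-side TMA descriptors: the launch grid is
    # sized to NUM_SMS programs, and each program loops over output tiles of
    # A @ B with a stride of NUM_SMS until all tiles are covered.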
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    start_pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = grid_m * grid_n

    # Note: We require TMA_EXPERIMENTAL_API == False, which
    # we will check before invoking this template.
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}
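    # Build the TMA tensor descriptors once per program; they describe the
    # global-memory layout of A and B so the hardware can issue bulk
    # asynchronous copies. For a column-major operand the descriptor is built
    # on the transposed view, keeping a unit stride along the contiguous axis.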
    a_desc = tl.make_tensor_descriptor(
        base=A,
        shape=[M, K] if A_ROW_MAJOR else [K, M],
        strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1],
        block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
    )
    b_desc = tl.make_tensor_descriptor(
        base=B,
        shape=[K, N] if B_ROW_MAJOR else [N, K],
        strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1],
        block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
    )

    # tile_id_c is a separate tile counter for the epilogue: it starts one
    # NUM_SMS stride behind start_pid and is advanced by NUM_SMS each
    # iteration, so recomputing the pid from it breaks the data dependency
    # between the prologue and the epilogue and lets the two be overlapped.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_M * grid_n

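    # Round-robin over output tiles: this program handles tiles start_pid,
    # start_pid + NUM_SMS, start_pid + 2 * NUM_SMS, ... FLATTEN and
    # WARP_SPECIALIZE are forwarded to tl.range as loop-scheduling knobs.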
    for tile_id in tl.range(
        start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE
    ):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
        )
        offs_am = pid_m * BLOCK_M
        offs_bn = pid_n * BLOCK_N

        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
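        # March down the K dimension in BLOCK_K slices, accumulating into a
        # float32 tile regardless of the input dtype.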
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_K
            a = tl.load_tensor_descriptor(
                a_desc,
                [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am],
            )
            b = tl.load_tensor_descriptor(
                b_desc,
                [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k],
            )
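            # Loads arrive in the descriptor's native layout; transpose
            # column-major operands in registers so tl.dot always computes
            # (BLOCK_M, BLOCK_K) @ (BLOCK_K, BLOCK_N).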
            accumulator += tl.dot(
                a if A_ROW_MAJOR else a.T,
                b if B_ROW_MAJOR else b.T,
                allow_tf32=ALLOW_TF32,
            )

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_M
        offs_cn = pid_n * BLOCK_N
        {%- if EPILOGUE_SUBTILE %}
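        # Subtiled epilogue: split the accumulator into two
        # (BLOCK_M, BLOCK_N // 2) halves and store them separately to halve
        # the register pressure of the store.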
        tl.static_assert(BLOCK_N % 2 == 0)
        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc0, acc1 = tl.split(acc)
        {{store_output(
            ("offs_cm", "offs_cn"),
            "acc0",
            indent_width=8,
            val_shape=("BLOCK_M", "BLOCK_N // 2"),
            block_indexing=True
        )}}
        offs_cn2 = offs_cn + BLOCK_N // 2
        {{store_output(
            ("offs_cm", "offs_cn2"),
            "acc1",
            indent_width=8,
            val_shape=("BLOCK_M", "BLOCK_N // 2"),
            block_indexing=True
        )}}
        {%- else %}
        {{store_output(
            ("offs_cm", "offs_cn"),
            "accumulator",
            indent_width=8,
            val_shape=("BLOCK_M", "BLOCK_N"),
            block_indexing=True
        )}}
        {%- endif %}

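# Map a linear tile index to a (pid_m, pid_n) pair using grouped ordering:
# tiles are traversed in groups of GROUP_M rows of the output so that nearby
# programs reuse the same A and B tiles, improving L2 locality.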
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group along M may be ragged; clamp the group size so pid_m
    # stays within grid_m.
    group_size_m = min(grid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
